import matplotlib.pyplot as plt
from scipy.io import loadmat
import spectral as spy
import numpy as np


# def unsupervised_demo(src):  # K-means法 迭代方法生成聚类
#     m, c = spy.kmeans(src, nclusters=6, max_iterations=30)  # 分为6类，最大迭代30次
#     spy.imshow(classes=m)  # 显示分类结果
#     plt.figure()
#     for i in range(c.shape[0]):  # 显示分类后的各光谱曲线
#         plt.plot(c[i])
#     plt.pause(60)
#
#
# def supervised_demo(src,gt):
#     classes = spy.create_training_classes(src, gt)  # 创建训练类集合
#     gmlc = spy.GaussianClassifier(classes)  # 高斯的最大似然分类法
#     clmap = gmlc.classify_image(src)
#     spy.imshow(classes=clmap)
#     gt_results = clmap * (gt != 0)  # 为分好类的图像设定一个Mask
#     gt_right = gt_results * (gt_results == gt)
#     gt_errors = gt_results * (gt_results != gt)
#     spy.imshow(classes=gt_right, title="right")  # 分类正确的部分
#     spy.imshow(classes=gt_errors, title="errors")  # 分类错误的部分
#     precision_evaluation(gt_results ,gt)  # 精度评定
#     plt.pause(60)
#
#
# def precision_evaluation(cla, gt):  # 精度评定
#     def count_number(src):  # 统计分类数据
#         dict_k = {}
#         for row in range(src.shape[0]):
#             for col in range(src.shape[1]):
#                 if src[row][col] not in dict_k:
#                     dict_k[src[row][col]] = 0
#                 dict_k[src[row][col]] += 1
#         dict_k = dict(sorted(dict_k.items()))
#         del dict_k[0]  # 键为0的是未归类的部分,所以去掉
#         class_sum = sum(dict_k.values())
#         return dict_k, class_sum
#
#     cla_dic, cla_sum = count_number(cla)  # 分类后的
#     gt_dic, gt_sum = count_number(gt)  # 真实的
#     gt_right = cla * (cla == gt)
#     gt_right_dic, gt_right_sum = count_number(gt_right)  # 分类正确的
#
#     p0 = gt_right_sum / gt_sum
#     pe = 0
#
#     for gt_key in gt_dic:
#         if gt_key not in cla_dic:
#             cla_dic[gt_key] = 0
#             gt_right_dic[gt_key] = 0
#             print("类别%s的用户精度为：0.0000,生产者精度为：0.0000" % gt_key)
#         else:
#             print("类别%s的用户精度为：%.4f," % (gt_key, gt_right_dic[gt_key] / cla_dic[gt_key]), end='')
#             print("生产者精度为：%.4f" % (gt_right_dic[gt_key] / gt_dic[gt_key]))
#         pe += gt_dic[gt_key] * cla_dic[gt_key]
#
#     pe = pe / (gt_sum * gt_sum)
#     kappa = (p0 - pe) / (1 - pe)
#     overall_accuracy = gt_right_sum / gt_sum
#     print("-" * 36)
#     print("Kappa=", kappa)
#     print("overall_accuracy", overall_accuracy)


# # 加载mat格式的数据。loadmat输出的是dict，所以需要进行定位
input_image = loadmat('C:/Users/Cojin/Desktop/软件工程原理与实践/期末项目/数据集/Indian_pines_corrected .mat')['indian_pines_corrected']
# gt = loadmat('C:/Users/Cojin/Desktop/软件工程原理与实践/期末项目/数据集/Indian_pines_gt.mat')['indian_pines_gt']  # 加载真实类别
# 加载mat格式的数据。loadmat输出的是dict，所以需要进行定位

# input_image = loadmat('C:/Users/Cojin/Desktop/软件工程原理与实践/期末项目/数据集/SalinasA_corrected.mat')['salinasA_corrected']

# 可视化影像
view = spy.imshow(data=input_image, bands=[69, 27, 11], figsize=(6, 6))
plt.pause(60)
# # 分类
# unsupervised_demo(input_image)
# supervised_demo(input_image, gt)